def _kde_pdf(samples, x):
    """Return a Gaussian-KDE density of the 1-D ``samples`` evaluated on grid ``x``."""
    return gaussian_kde(samples)(x)


def _create_df(pdfs, labels, x, x_name="theta"):
    """Build a long-format DataFrame (one row per estimator/grid point) for ``px.line``.

    Args:
        pdfs: sequence of 1-D arrays, each evaluated on the grid ``x``.
        labels: one legend label per entry of ``pdfs``, in the same order.
        x: 1-D evaluation grid shared by all PDFs.
        x_name: column name used for the grid values (becomes the x axis).

    Returns:
        DataFrame with columns ``x_name``, ``"PDF"`` and ``"label"``.

    Raises:
        ValueError: if the number of PDFs and labels differ.
    """
    if len(pdfs) != len(labels):
        raise ValueError(f"got {len(pdfs)} PDFs but {len(labels)} labels")
    n_points = x.shape[0]
    return pd.DataFrame({
        # Grid repeated once per estimator, matching the flattened PDF order.
        x_name: jnp.tile(x, len(labels)),
        "PDF": jnp.array(pdfs).reshape(-1),
        # Each label repeated once per grid point.
        "label": [label for label in labels for _ in range(n_points)],
    })


def plot_all(varient='', *, x_min=0.0, x_max=6.0, num_points=10000):
    """Plot KDE posterior densities from Blackjax RMH and Ajax VI results.

    Loads the pickled results ``./results_data/MCMC_Blackjax<varient>`` and
    ``./results_data/ajax_model<varient>``. It fits a Gaussian KDE to the
    samples of each of the two ``theta`` components and of the noise
    parameter, then shows two Plotly figures: one for theta0/theta1 and one
    for the noise parameter.

    Args:
        varient: suffix appended to each results file name (passed through ``str``).
        x_min, x_max, num_points: evaluation grid shared by every density
            curve. The defaults reproduce the original ``linspace(0, 6, 10000)``.
    """
    varient = str(varient)
    x = jnp.linspace(x_min, x_max, num_points)

    theta_pdfs, theta_labels = [], []
    noise_pdfs, noise_labels = [], []

    # SECURITY NOTE: pickle.load can execute arbitrary code; only load
    # result files produced by trusted runs.
    with open('./results_data/MCMC_Blackjax' + varient, 'rb') as f:
        black_samples = pickle.load(f)
    for i in range(2):
        theta_pdfs.append(_kde_pdf(black_samples.position['theta'][:, i], x))
        theta_labels.append(f'Blackjax rmh theta{i}')
    # NOTE(review): Blackjax stores 'noise_var' while Ajax exposes 'noise';
    # confirm both are the same quantity (variance vs. std) before comparing.
    noise_pdfs.append(_kde_pdf(black_samples.position['noise_var'], x))
    noise_labels.append("Blackjax rmh noise")

    with open("./results_data/ajax_model" + varient, 'rb') as f:
        posterior = pickle.load(f)
    samples_ajax = posterior.sample(seed=jax.random.PRNGKey(10), sample_shape=(10000,))
    for i in range(2):
        theta_pdfs.append(_kde_pdf(samples_ajax["theta"][:, i], x))
        theta_labels.append(f"Ajax VI theta{i}")
    noise_pdfs.append(_kde_pdf(samples_ajax["noise"], x))
    noise_labels.append("Ajax VI noise")

    df = _create_df(theta_pdfs, theta_labels, x, x_name="theta")
    fig = px.line(df, "theta", "PDF", color="label",
                  title="Linear regression posterior: theta")
    fig.show()

    # Separate axis name and title so the noise figure is not mislabelled as theta.
    df = _create_df(noise_pdfs, noise_labels, x, x_name="noise")
    fig = px.line(df, "noise", "PDF", color="label",
                  title="Linear regression posterior: noise")
    fig.show()